# Drug Recommendation using MICRON on MIMIC-III Dataset

import os
import numpy as np
import pandas as pd
import torch
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.models import MICRON
from pyhealth.trainer import Trainer
# from pyhealth.metrics import multilabel_metrics # Removed this import

# Seed the NumPy and PyTorch generators (and every CUDA device, when one is
# present) so that repeated runs of this notebook are reproducible.
_SEED = 42
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(_SEED)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _use_cuda else "cpu")
print(f"Using device: {DEVICE}")

"""## 3. Load and Process MIMIC-III Dataset

We'll load the MIMIC-III dataset using PyHealth's built-in dataset loader and prepare it for training. The dataset will be processed to include patient diagnoses, procedures, and medications.
"""

# Configuration
# Root of the MIMIC-III data. Defaults to the PhysioNet MIMIC-III demo; set the
# MIMIC3_PATH environment variable to point at a local copy (or the full
# dataset) instead of editing this file.
# Alternative source: "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III"
MIMIC3_PATH = os.environ.get(
    "MIMIC3_PATH", "https://physionet.org/files/mimiciii-demo/1.4/"
)

# Load diagnoses, procedures and prescriptions: they supply the conditions,
# procedures and drugs used by the drug-recommendation task below.
# dev=True asks PyHealth for a reduced development subset of patients.
dataset = MIMIC3Dataset(
    root=MIMIC3_PATH,
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    dev=True,
)
dataset.stats()

"""## 3. Set Drug Recommendation Task

Use the DrugRecommendationMIMIC3 task function which creates samples with conditions, procedures, and atc-3 codes (drugs).
"""

from pyhealth.tasks import DrugRecommendationMIMIC3

# Convert the raw patient records into task samples: conditions and procedures
# per visit, plus the drug codes to predict.
task = DrugRecommendationMIMIC3()
samples = dataset.set_task(task, num_workers=4)

print("Sample Dataset Statistics:")
print(f"\t- Dataset: {samples.dataset_name}")
print(f"\t- Task: {samples.task_name}")
print(f"\t- Number of samples: {len(samples)}")

# Fail fast with an actionable message rather than an IndexError below (or an
# empty train/val/test split later) when the task yields no samples, e.g.
# because MIMIC3_PATH points at the wrong location.
if len(samples) == 0:
    raise RuntimeError(
        "DrugRecommendationMIMIC3 produced no samples; check MIMIC3_PATH "
        "and the loaded tables."
    )

first_sample = samples.samples[0]
print("\nFirst sample structure:")
print(f"Patient ID: {first_sample['patient_id']}")
# 'conditions' and 'procedures' hold one code list per visit.
print(f"Number of visits: {len(first_sample['conditions'])}")
print(f"Sample conditions (first visit): {first_sample['conditions'][0][:5]}...")
print(f"Sample procedures (first visit): {first_sample['procedures'][0][:5]}...")
print(f"Sample drugs (target): {first_sample['drugs'][:10]}...")

"""## Split Dataset and Create Data Loaders"""

from pyhealth.datasets import split_by_patient, get_dataloader

# Split at the patient level (70% train / 10% validation / 20% test) so that
# no patient's visits leak across splits. The explicit seed makes the split
# reproducible regardless of what consumed NumPy's global RNG earlier.
train_dataset, val_dataset, test_dataset = split_by_patient(
    samples, ratios=[0.7, 0.1, 0.2], seed=42
)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Only the training batches are shuffled; evaluation order stays fixed.
BATCH_SIZE = 32
train_dataloader = get_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

"""## 4. Initialize and Configure MICRON Model

Now we'll set up the MICRON model with appropriate hyperparameters for drug recommendation.
"""

# MICRON hyperparameters: embedding and hidden sizes, plus `lam`, the
# regularization weight applied to the reconstruction loss.
model_params = dict(embedding_dim=128, hidden_dim=128, lam=0.1)

# Build the model from the processed sample dataset, then place it on DEVICE.
model = MICRON(dataset=samples, **model_params)
model = model.to(DEVICE)

print(model)


# Every evaluation reports sample-averaged Jaccard, F1 and PR-AUC, plus the
# drug-drug-interaction ("ddi") score of the recommended drug sets.
trainer = Trainer(
    model=model,
    metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"],
)

# Score the untrained model first so the effect of training can be compared.
print("Baseline performance before training:")
baseline_results = trainer.evaluate(test_dataloader)
print(baseline_results)

"""## 5. Train the Model

Let's train the MICRON model on our processed MIMIC-III dataset.
"""

# Fit on the training split and validate on the validation split; `monitor`
# names the validation metric the trainer tracks when selecting the best model.
train_options = {"epochs": 5, "monitor": "pr_auc_samples"}
history = trainer.train(train_dataloader, val_dataloader, **train_options)

# Optionally persist the trained weights (adjust the path for your environment):
#torch.save(model.state_dict(), "/content/drive/MyDrive/micron_model.pt")

"""## 6. Evaluate Model Performance

Finally, let's evaluate our trained model on the test set and visualize the results.
"""

test_results = trainer.evaluate(test_dataloader)
print("Final test set performance:")
print(test_results)


def format_metric(results, key, precision=4):
    """Format ``results[key]`` as a fixed-point string, or ``"N/A"``.

    A float format spec such as ``.4f`` cannot be applied to the string
    ``"N/A"`` (it raises ``ValueError``), so the fallback must be chosen
    *after* formatting rather than passed as the ``.get`` default.

    Args:
        results: Mapping of metric name to value, as returned by
            ``trainer.evaluate``.
        key: Metric name to look up.
        precision: Number of digits after the decimal point.

    Returns:
        The value formatted with ``precision`` decimals, or ``"N/A"`` if the
        key is missing or its value cannot be converted to ``float``.
    """
    value = results.get(key)
    if value is None:
        return "N/A"
    try:
        return f"{float(value):.{precision}f}"
    except (TypeError, ValueError):
        return "N/A"


print("\nKey Metrics:")
print(f"  PR-AUC: {format_metric(test_results, 'pr_auc_samples')}")
print(f"  F1 Score: {format_metric(test_results, 'f1_samples')}")
print(f"  Jaccard: {format_metric(test_results, 'jaccard_samples')}")
print(f"  DDI Rate: {format_metric(test_results, 'ddi_score')} (lower is better)")

# Example results recorded from one run (values will vary between runs/data):
# before training
# {'jaccard_samples': 0.14468260340523165, 'f1_samples': 0.25038132615979186, 'pr_auc_samples': 0.16410485822712173, 'ddi_score': 0.0, 'loss': 26.26311683654785}

# Post training
# {'jaccard_samples': 0.20728042582268005, 'f1_samples': 0.33950176113659936, 'pr_auc_samples': 0.27233906976425537, 'ddi_score': 0.0, 'loss': 4.690524578094482}

"""## Conclusion

We have successfully implemented and trained a MICRON model for drug recommendation using the MIMIC-III dataset. The model's performance can be evaluated using the metrics above:

1. DDI score: Shows the drug-drug interaction score indicating the effect of one drug on another. lower is better.
2. Precision: Indicates how many of the predicted drugs were actually correct
3. Recall: Shows how many of the actual drugs were correctly predicted
4. F1 Score: The harmonic mean of precision and recall

The confusion matrix visualization helps us understand where the model performs well and where it might need improvement. The training loss plot shows how the model learned over time.

Next steps could include:
- Hyperparameter tuning to improve performance
- Testing with different model architectures
- Analyzing specific cases where the model performs well or poorly
- Incorporating additional patient features
"""